# Kernel: presistent_birnn

# Copyright 2016 Nervana Systems Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


// Symbolic names for the shared-memory zero slot and the constant-bank slots
// used by the rest of this file.
//  - addr_zero is a shared-memory address that the kernel fills with zeros at
//    start-up and then reads with LDS wherever a predicated-off load needs a
//    zero value.
//  - The param_* entries occupy consecutive 4-byte slots of c[0x0] starting at
//    0x140. Entries with a [0]/[1] suffix are used as the low/high halves of a
//    64-bit address by the LEA / LEA.HI.X pairs below.
//  - gridDimA, gridDimB and param_bsz are declared here but are not referenced
//    anywhere else in the visible part of this file.
<CONSTANT_MAPPING>
    addr_zero : 4x<(64*48)>

    gridDimA : c[0x0][0x14]
    gridDimB : c[0x0][0x18]

    param_h[0]        : c[0x0][0x140]
    param_h[1]        : c[0x0][0x144]
    param_hprev[0]    : c[0x0][0x148]
    param_hprev[1]    : c[0x0][0x14c]
    param_bias[0]     : c[0x0][0x150]
    param_bias[1]     : c[0x0][0x154]
    param_w[0]        : c[0x0][0x158]
    param_w[1]        : c[0x0][0x15c]
    param_lockAddr[0] : c[0x0][0x160]
    param_lockAddr[1] : c[0x0][0x164]
    param_ldh         : c[0x0][0x168]
    param_ldw         : c[0x0][0x16c]
    param_bsz         : c[0x0][0x170]
    param_seqLength   : c[0x0][0x174]
    param_numBlks     : c[0x0][0x178]
    param_rowSize     : c[0x0][0x17c]
    param_reverse     : c[0x0][0x180]
    param_reluclip    : c[0x0][0x184]
</CONSTANT_MAPPING>

// Register names used by the kernel.
//  - Registers 0-215 hold the weights (three groups of 72: 0-71, 72-143 and
//    144-215 are addressed as row*72 + index by the generators below) and
//    216-227 hold the 12 accumulators; these ranges are not reused.
//  - The remaining ranges deliberately overlap, so several names share one
//    physical register and are only meaningful in their own phase:
//      * 233 is named bid, hOffset and ldh.
//      * 236-243 (weight/bias address setup) overlaps 234-239
//        (hprevAddr/loadBuffer) and 240-251 (hidden/peer/output registers).
//      * 240-251 is shared by the hidden*, peerR* and output/address names.
//      * 244-254, 252-254 and 250-254 hold per-phase scratch values.
//  - NOTE(review): lines using ':' presumably bind names to the range in the
//    listed order while '~' lets the assembler choose the order inside the
//    range -- confirm against the assembler documentation.
<REGISTER_MAPPING>

      0-215 : weight<000-215>
    216-227 : accum<00-11>
    228-229 : timeStep, biasValue
    230-232 : warpTid, rowOffset, tid

    233     : bid

    236-243 : wAddr0r<0-1>, wAddr1r<0-1>, wAddr2r<0-1>, biasAddr<0-1>
    244-254 ~ ldw, wRow, warpTid4, loadRow, warpIndex, storeWeights, loadWeights, rowSize

    233     : hOffset
    233     : ldh
    234-239 : hprevAddr<0-1>, loadBuffer<0-3>
    240-251 : hidden0r<0-3>, hidden1r<0-3>, hidden2r<0-3>
    252-254 ~ loadHiddens, storeHiddens, loadIndex

    240-251 : peerR0V<0-3>, peerR1V<0-3>, peerR2V<0-3>

    240-249 : output<0-3>, hAddr<0-1>, lockAddr<0-1>, expectVal, setVal
    250-254 ~ storeIndex, hRow, predSave, lockVal, reluclip

</REGISTER_MAPPING>

//One-time setup: read thread/block ids, derive the rows this half-warp owns,
//the shared-memory offsets used while staging weights, the global addresses of
//three consecutive weight rows, and the predicates guarding those rows.

//Get tid/block id
--:-:1:-:1      S2R tid, SR_TID.X;
--:-:2:-:1      S2R bid, SR_CTAID.X;

//Store zeros at addr_zero (read back later as the zero source for
//predicated-off loads and for clearing the accumulators)
--:-:-:-:1      STS.128 [addr_zero], RZ;

<SCHEDULE_BLOCK>
--:-:-:-:1      MOV     ldw,       param_ldw;
--:-:-:-:1      MOV     rowSize,   param_rowSize;

//timeStep = (param_reverse == 0) ? 0 : param_seqLength - 1
--:-:-:-:1      ISETP.EQ.AND P2, PT, RZ, param_reverse, PT;
--:-:-:-:1      SEL timeStep, RZ, param_seqLength, P2;
--:-:-:-:1 @!P2 IADD timeStep, timeStep, -1;

//warpIndex = threadIdx.x >> 5
01:-:-:-:1      SHR.U32 warpIndex, tid, 5;

//warpTid = threadIdx.x & 0x1f
01:-:-:-:1      LOP.AND warpTid,   tid, 0x1f;

//rowOffset = ((blockIdx.x << 3) + warp_index) * 6
02:-:-:-:1      SHL     rowOffset, bid,       3;
--:-:-:-:1      IADD    rowOffset, rowOffset, warpIndex;
--:-:-:-:1      XMAD    rowOffset, rowOffset, 6, RZ;

//if(warp_tid > 15) rowOffset += 3
//(P1 marks the upper half-warp; each half-warp therefore starts at its own
// group of 3 rows)
--:-:-:-:1      ISETP.GT.AND P1, PT, warpTid, 15, PT;
--:-:-:-:1  @P1 IADD     rowOffset, rowOffset, 3;

//warpTid = warpTid & 0x0f
//P0 = (warpTid < 3); it is later ANDed with a row bound and guards the bias
//load, the load of the current h values and the final store
--:-:-:-:1      LOP.AND  warpTid, warpTid, 0x0f;
--:-:-:-:1      ISETP.LT.AND P0, PT, warpTid, 3, PT;

//warpTid4 = warpTid << 2
--:-:-:-:1      SHL      warpTid4, warpTid, 2;

//base         = (warpIndex * 6 + (P1 ? 3 : 0)) << 6
//storeWeights = (base + warpTid4) << 2
//loadWeights  = (base + warpTid)  << 2
--:-:-:-:1  @P1 MOV      loadWeights, 3;
--:-:-:-:1 @!P1 MOV      loadWeights, RZ;

--:-:-:-:1      XMAD     loadWeights, warpIndex, 6, loadWeights;
--:-:-:-:1      SHL      loadWeights, loadWeights, 6;

--:-:-:-:1      IADD     storeWeights, loadWeights, warpTid4;
--:-:-:-:1      IADD     loadWeights, loadWeights, warpTid;
--:-:-:-:1      SHL      storeWeights, storeWeights, 2;
--:-:-:-:1      SHL      loadWeights, loadWeights, 2;

//wRow = rowOffset * ldw + warpTid4
--:-:-:-:1      XMAD     wRow, rowOffset, ldw, warpTid4;

//wAddr0r = &w[wRow]
--:-:-:-:1      LEA      wAddr0r0.CC, wRow, param_w[0],     2;
--:-:-:-:1      LEA.HI.X wAddr0r1,    wRow, param_w[1], RZ, 2;

//ldw = ldw << 2
--:-:-:-:1      SHL      ldw,  ldw,       2;

//wAddr1r = wAddr0r + ldw
--:-:-:-:1      IADD     wAddr1r0.CC, wAddr0r0, ldw;
--:-:-:-:1      IADD.X   wAddr1r1,    wAddr0r1, RZ;

//wAddr2r = wAddr1r + ldw
--:-:-:-:1      IADD     wAddr2r0.CC, wAddr1r0, ldw;
--:-:-:-:1      IADD.X   wAddr2r1,    wAddr1r1, RZ;

//Compute row loading predicates
//P1 = (warpTid4 < rowSize)            (reuses P1; the half-warp flag is no longer needed)
//P3 = P1 && (rowOffset     < rowSize)
//P4 = P1 && (rowOffset + 1 < rowSize)
//P5 = P1 && (rowOffset + 2 < rowSize)
//NOTE(review): the rowSize register is left at param_rowSize - 2 after this
//block, and the generated weight-load code below compares warpTid4 against
//that decremented value -- confirm this is intended.
--:-:-:-:1      ISETP.LT.AND P1, PT, warpTid4, rowSize, PT;
--:-:-:-:1      ISETP.LT.AND P3, PT, rowOffset, rowSize, P1;
--:-:-:-:1      IADD     rowSize, rowSize, -1;
--:-:-:-:1      ISETP.LT.AND P4, PT, rowOffset, rowSize, P1;
--:-:-:-:1      IADD     rowSize, rowSize, -1;
--:-:-:-:1      ISETP.LT.AND P5, PT, rowOffset, rowSize, P1;
</SCHEDULE_BLOCK>

--:-:-:Y:c      NOP;

//Load weights to registers
<CODE>
    # Generates the one-time weight preload. The 1152 columns are handled in
    # chunks of 64. For each chunk, three weight rows are read with 128-bit
    # loads (or zero-filled from addr_zero when the row predicate is off),
    # staged through shared memory at storeWeights, and then read back one
    # word at a time from loadWeights into the final weight registers.
    my $rowsize  = 1152;
    my $bar_sync = "--:-:-:-:5      BAR.SYNC 0;\n\n";

    # Per weight row: first register of the row's group, guarding predicate,
    # barrier slot set by the loads, global address register pair, wait mask
    # used by the staging store, and the staging store destination.
    my @weight_rows = (
        { first_reg => 0,   pred => 'P3', barrier => 1, addr => 'wAddr0r', wait => '01', store => 'storeWeights' },
        { first_reg => 72,  pred => 'P4', barrier => 2, addr => 'wAddr1r', wait => '02', store => 'storeWeights + 4x<64>' },
        { first_reg => 144, pred => 'P5', barrier => 3, addr => 'wAddr2r', wait => '04', store => 'storeWeights + 4x<128>' },
    );

    my @asm;
    my $col = 0;

    while ($col < $rowsize)
    {
        my $chunk_reg      = $col / 16;
        my $is_final_chunk = (($col + 64) >= $rowsize);

        # Advance the column index used by the predicate refresh below
        push @asm, "--:-:-:-:1      IADD warpTid4, warpTid4, 64;\n";

        # Vector loads from the weight matrix, zero-filled when out of range
        foreach my $w (@weight_rows)
        {
            my $reg = $chunk_reg + $w->{first_reg};
            push @asm, sprintf("--:-:%d:-:1  \@%s LDG.E.128 weight%03d, [%s + 4x<%d>];\n",
                               $w->{barrier}, $w->{pred}, $reg, $w->{addr}, $col);
            push @asm, sprintf("--:-:%d:-:1 \@!%s LDS.U.128 weight%03d, [addr_zero];\n",
                               $w->{barrier}, $w->{pred}, $reg);
        }

        # Refresh the row predicates for the next chunk
        foreach my $w (@weight_rows)
        {
            push @asm, sprintf("--:-:-:-:1      ISETP.LT.AND %s, PT, warpTid4, rowSize, %s;\n",
                               $w->{pred}, $w->{pred});
        }

        # Every chunk but the first must wait until the previous chunk's
        # shared-memory reads are finished before overwriting the staging area
        push @asm, $bar_sync if $col > 0;

        # Stage the freshly loaded vectors in shared memory
        foreach my $w (@weight_rows)
        {
            push @asm, sprintf("%s:-:-:-:1      STS.U.128 [%s], weight%03d;\n",
                               $w->{wait}, $w->{store}, $chunk_reg + $w->{first_reg});
        }

        push @asm, $bar_sync;

        # Read each weight back from shared memory into its final register
        foreach my $shared_col (0 .. 3)
        {
            foreach my $row (0 .. $#weight_rows)
            {
                # Only the very last read of the whole preload carries barrier flags
                my $is_last_read = $is_final_chunk && $row == 2 && $shared_col == 3;
                my $control      = $is_last_read ? "--:1:6:-:2" : "--:-:-:-:1";

                my $reg           = $weight_rows[$row]{first_reg} + $chunk_reg + $shared_col;
                my $shared_offset = ($row * 64) + ($shared_col * 16);

                push @asm, sprintf("%s      LDS.U weight%03d, [loadWeights + 4x<%d>];\n",
                                   $control, $reg, $shared_offset);
            }
        }

        $col += 64;
    }

    push @asm, $bar_sync;

    return join '', @asm;

</CODE>

//Conditional load of bias
//loadRow   = rowOffset + warpTid
//P0        = P0 && (loadRow < param_rowSize)   (incoming P0 is warpTid < 3)
//biasValue = P0 ? bias[loadRow] : 0
//NOTE(review): unlike the other LDG instructions in this file, this load sets
//no barrier slot in its control code; biasValue is not read until the bias add
//after the GEMM loop -- confirm that relying on that distance is intentional.
<SCHEDULE_BLOCK>
01:-:-:-:1      IADD     loadRow,      rowOffset, warpTid;
--:-:-:-:1      ISETP.LT.AND P0, PT, loadRow, param_rowSize, P0;
--:-:-:-:1      LEA      biasAddr0.CC, loadRow,   param_bias[0],     2;
--:-:-:-:1      LEA.HI.X biasAddr1,    loadRow,   param_bias[1], RZ, 2;
--:-:-:-:1  @P0 LDG.E    biasValue,    [biasAddr];
--:-:-:-:1 @!P0 MOV      biasValue,    RZ;
</SCHEDULE_BLOCK>

//Predicates for store code
//P2/P3/P4 = (warpTid == 0/1/2); they select which of the three accumulator
//rows a lane writes out after the reduction
--:-:-:-:1      ISETP.EQ.AND P2, PT, warpTid, 0, PT;
--:-:-:-:1      ISETP.EQ.AND P3, PT, warpTid, 1, PT;
--:-:-:-:1      ISETP.EQ.AND P4, PT, warpTid, 2, PT;

//Top of the per-timestep loop; the branch at the end of the file returns here
//while timeStep has not reached its end value.
UNROLLING_LOOP:
<SCHEDULE_BLOCK>
//Prime inner product loop by loading first rows of hprev
--:-:-:-:1      MOV loadIndex,    tid;

//storeHiddens = tid << 4
//loadHiddens  = warpTid << 4
--:-:-:-:1      SHL storeHiddens, tid, 4;
--:-:-:-:1      SHL loadHiddens, warpTid, 4;

//hOffset   = loadIndex * param_ldh + timeStep
//hprevAddr = param_hprev + (hOffset << 4)
//NOTE(review): the previous comment described this as
//&h_prev[timeStep * ldh + loadIndex]; the XMAD operands give the form above --
//confirm which layout is intended.
--:-:-:-:1      XMAD     hOffset,        loadIndex, param_ldh,      timeStep;
--:-:-:-:1      LEA      hprevAddr0.CC,  hOffset,   param_hprev[0],     4;
--:-:-:-:2      LEA.HI.X hprevAddr1,     hOffset,   param_hprev[1], RZ, 4;

//loadBuffer = (loadIndex < param_rowSize) ? *hprevAddr : 0
--:-:-:-:1      ISETP.LT.AND P1, PT, loadIndex, param_rowSize, PT;
--:5:1:-:2  @P1 LDG.E.CI.128 loadBuffer, [hprevAddr];
--:5:1:-:2 @!P1 LDS.U.128    loadBuffer, [addr_zero];

//ldh = param_ldh << 12
//(the byte distance hprevAddr moves when loadIndex advances by 256:
// 256 * param_ldh * 16). ldh shares register 233 with hOffset, which is no
// longer needed once hprevAddr has been formed.
--:-:-:-:1      MOV ldh, param_ldh;
--:-:-:-:1      SHL ldh, ldh, 12;
</SCHEDULE_BLOCK>

//Initialize all accumulation registers to 0
<CODE>
    # Clear the 12 accumulators with three 128-bit reads of the zero slot,
    # one per group of four registers (accum00, accum04, accum08).
    my $clear_code = '';
    foreach my $first_accum (0, 4, 8)
    {
        $clear_code .= sprintf "--:-:-:-:1      LDS.U.128 accum%02d, [addr_zero];\n", $first_accum;
    }
    return $clear_code;
</CODE>

//Update load index and load address
//loadIndex += 256, P1 = (loadIndex < param_rowSize), hprevAddr += ldh
//(64-bit add: the low half sets the carry consumed by IADD.X)
--:-:-:-:6      IADD loadIndex, loadIndex, 256;
--:-:-:-:1      ISETP.LT.AND P1, PT, loadIndex, param_rowSize, PT;
10:-:-:-:6      IADD   hprevAddr0.CC, hprevAddr0, ldh;
--:-:-:-:6      IADD.X hprevAddr1,    hprevAddr1, RZ;

//Publish the first hprev vector to shared memory at storeHiddens
01:-:-:-:1      STS.U.128 [storeHiddens], loadBuffer;

//Unrolled GEMM loop
<CODE>
    # Generates the fully unrolled inner-product loop for one timestep.
    #
    # The outer loop walks k = 0, 256, ..., 1024 and each iteration is split
    # into 16 "shared rows". Every shared row that still has a weight
    # (weight_index < 72) emits 12 FFMAs:
    #     accum[row*4 + col] += weight[row*72 + weight_index] * hidden<read_buffer>r<col>
    # for row in 0..2 and col in 0..3. Interleaved with the FFMAs are the
    # shared-memory reads of upcoming hidden vectors, the global prefetch of
    # the next hprev vector, and the double-buffer swap (XOR with 4096) of the
    # storeHiddens/loadHiddens offsets.
    #
    # The order of the emitted lines and their control codes is significant;
    # do not reorder the pushes below.

    # Any text queued in the package variable @top is emitted first
    # (it is not assigned anywhere in this block).
    our @top;

    my $out = join '', @top;

    my $rowsize = 1152;
    my $weight_index = 0;

    # Rotating state: barrier flag waited on by the current FFMA group, barrier
    # flag set by the next hidden-vector read, and the hidden register group
    # (0..2) being read from / written to. Flags cycle through 2..4.
    my $wait_flag = 2;
    my $set_flag = 4;
    my $read_buffer = 0;
    my $write_buffer = 2;

    for (my $k=0; $k < $rowsize; $k+=256)
    {
        if ($k == 0)
        {
            # First iteration only: start the second hprev prefetch, then read
            # the first two hidden vectors after the barrier.
            $out .= "--:6:1:-:1  \@P1 LDG.E.CI.128 loadBuffer, [hprevAddr];\n";
            $out .= "--:-:1:-:1 \@!P1 LDS.U.128    loadBuffer, [addr_zero];\n\n";
            $out .= "--:-:-:-:5      BAR.SYNC 0;\n\n";
            $out .= "--:-:2:-:1      LDS.U.128 hidden0r, [loadHiddens];\n";
            $out .= "--:-:3:-:1      LDS.U.128 hidden1r, [loadHiddens + 4x<4*16>];\n\n";
        }
        # Point storeHiddens at the other shared-memory buffer
        $out .= "--:-:-:-:1      LOP.XOR storeHiddens, storeHiddens, 4096;\n";

        foreach my $shared_row (0 .. 15)
        {
            # Steps past the last weight emit nothing, but the rotating state
            # at the bottom of this loop still advances.
            if($weight_index < 72)
            {
                # Read the hidden vector needed two steps from now, as long as
                # it lies inside the current buffer and inside the row.
                if ($shared_row < 14 && ($k + (16 * ($shared_row + 2))) < $rowsize)
                {
                    my $read_bar = "-";
                    # The last read before the loadHiddens swap sets flag 5,
                    # which the LOP.XOR on loadHiddens (shared_row 14) waits on.
                    if ($shared_row == 13 && ($k + 256) < $rowsize)
                    {
                        $read_bar = "5";
                    }
                    $out .= sprintf "--:%s:%d:-:1      LDS.U.128 hidden%dr, [loadHiddens + 4x<4*%d>];\n", $read_bar, $set_flag, $write_buffer, (16 * ($shared_row + 2));
                }

                # Advance the global load index / address (low half, sets carry)
                if ($shared_row == 11)
                {
                    $out .= "--:-:-:-:1      IADD loadIndex, loadIndex, 256;\n";
                    $out .= "20:-:-:-:1      IADD hprevAddr0.CC, hprevAddr0, ldh;\n";
                }

                # Refresh the bounds predicate and finish the 64-bit address add
                if ($shared_row == 12)
                {
                    $out .= "--:-:-:-:1      ISETP.LT.AND P1, PT, loadIndex, param_rowSize, PT;\n";
                    $out .= "--:-:-:-:1      IADD.X hprevAddr1,    hprevAddr1, RZ;\n";
                }

                if ($shared_row == 13)
                {
                    # Publish the previously prefetched hprev vector
                    $out .= "01:-:-:-:1      STS.U.128 [storeHiddens], loadBuffer;\n";

                    if (($k + 512) < $rowsize)
                    {
                        # More hprev data remains: prefetch the next vector
                        $out .= "--:6:1:-:1  \@P1 LDG.E.CI.128 loadBuffer, [hprevAddr];\n";
                        $out .= "--:-:1:-:1 \@!P1 LDS.U.128    loadBuffer, [addr_zero];\n\n";
                    }
                    else
                    {
                        # No more hprev data: reuse loadBuffer/hprevAddr to fetch
                        # the current contents of h at this thread's output
                        # location (same address formula as the store pointer);
                        # it is added to the result after the reduction.
                        $out .= "--:-:-:-:6      IADD     hOffset,        rowOffset, warpTid;\n";
                        $out .= "--:-:-:-:6      XMAD     hOffset,        hOffset,   param_ldh,  timeStep;\n";
                        $out .= "--:-:-:-:6      LEA      hprevAddr0.CC,  hOffset,   param_h[0],      4;\n";
                        $out .= "--:-:-:-:2      LEA.HI.X hprevAddr1,     hOffset,   param_h[1], RZ, 4;\n";
                        $out .= "--:-:6:-:1 \@P0 LDG.E.CI.128 loadBuffer, [hprevAddr];\n\n";
                    }
                }

                # Swap to the other shared-memory buffer and start reading it
                if ($shared_row == 14)
                {
                    $out .= "10:-:-:-:1      LOP.XOR loadHiddens, loadHiddens, 4096;\n";
                    $out .= "--:-:-:-:5      BAR.SYNC 0;\n\n";
                    $out .= sprintf "--:-:%d:-:1      LDS.U.128 hidden%dr, [loadHiddens];\n", $set_flag, $write_buffer;
                }

                if ($shared_row == 15)
                {
                    $out .= sprintf "--:-:%d:-:1      LDS.U.128 hidden%dr, [loadHiddens + 4x<4*16>];\n\n", $set_flag, $write_buffer;
                }

                foreach my $row (0 .. 2)
                {
                    my $weight = ($row * 72) + $weight_index;

                    foreach my $col (0 .. 3)
                    {
                        my $accum = ($row * 4) + $col;
                        my $wait = "--";
                        my $stall = 1;
                        # Only the first FFMA of a step waits: on the flag of the
                        # hidden vector it reads, plus 0x20 on the very first step.
                        if ($accum == 0)
                        {
                            if ($weight_index == 0)
                            {
                                $wait = sprintf "%02x", (0x20 | (1 << ($wait_flag - 1)));
                            }
                            else
                            {
                                $wait = sprintf "%02x", (1 << ($wait_flag - 1));
                            }
                        }

                        # The last FFMA of a step gets stall 0 when the next
                        # step begins with a shared-memory read (same conditions
                        # as the LDS emission above, shifted by one step).
                        if ($row == 2 && $col == 3)
                        {
                            if ($shared_row < 13 && ($k + (16 * ($shared_row + 3))) < $rowsize)
                            {
                                $stall = 0;
                            }
                            elsif ($shared_row == 14 && ($k + 256) < $rowsize)
                            {
                                $stall = 0;
                            }
                        }

                        $out .= sprintf "%s:-:-:-:%d      FFMA accum%02d, weight%03d, hidden%dr%d, accum%02d;\n", $wait, $stall, $accum, $weight, $read_buffer, $col, $accum;
                    }
                }

                $weight_index++;
            }

            # Rotate flags (2 -> 3 -> 4 -> 2) and hidden buffers (0 -> 1 -> 2 -> 0)
            $wait_flag += 1;
            $set_flag += 1;
            $read_buffer += 1;
            $write_buffer += 1;
            if($wait_flag == 5)
            {
                $wait_flag = 2;
            }
            if($set_flag == 5)
            {
                $set_flag = 2;
            }
            if($read_buffer == 3)
            {
                $read_buffer = 0;
            }
            if($write_buffer == 3)
            {
                $write_buffer = 0;
            }
        }
    }

    return $out;
</CODE>

//Reduction between threads
//Four butterfly-shuffle stages with lane masks 1, 2, 4 and 8. In each stage a
//lane fetches its partner's copy of all 12 accumulators into peerR*V* and adds
//it to its own, so after the last stage every lane in a group of 16 holds the
//group's sum. The shuffles for one pair of columns are interleaved with the
//adds for the other pair; the numbered control-code fields tie each add to the
//shuffle group that produced its operand.

//Stage 1 (mask 1): fetch partner values for columns 0-3
--:-:-:-:1      SHFL.BFLY PT, peerR0V0, accum00, 1, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V0, accum04, 1, 0x1f;
--:-:1:-:1      SHFL.BFLY PT, peerR2V0, accum08, 1, 0x1f;

--:-:-:-:1      SHFL.BFLY PT, peerR0V1, accum01, 1, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V1, accum05, 1, 0x1f;
--:-:2:-:1      SHFL.BFLY PT, peerR2V1, accum09, 1, 0x1f;

--:-:-:-:1      SHFL.BFLY PT, peerR0V2, accum02, 1, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V2, accum06, 1, 0x1f;
--:-:3:-:1      SHFL.BFLY PT, peerR2V2, accum10, 1, 0x1f;

--:-:-:-:1      SHFL.BFLY PT, peerR0V3, accum03, 1, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V3, accum07, 1, 0x1f;
--:-:4:-:1      SHFL.BFLY PT, peerR2V3, accum11, 1, 0x1f;

//Stage 1 adds for columns 0-1
01:-:-:-:1      FADD accum00, accum00, peerR0V0;
--:-:-:-:1      FADD accum04, accum04, peerR1V0;
--:-:-:-:1      FADD accum08, accum08, peerR2V0;

02:-:-:-:1      FADD accum01, accum01, peerR0V1;
--:-:-:-:1      FADD accum05, accum05, peerR1V1;
--:-:-:-:1      FADD accum09, accum09, peerR2V1;

//Stage 2 (mask 2) shuffles for columns 0-1
--:-:-:-:1      SHFL.BFLY PT, peerR0V0, accum00, 2, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V0, accum04, 2, 0x1f;
--:-:1:-:1      SHFL.BFLY PT, peerR2V0, accum08, 2, 0x1f;

--:-:-:-:1      SHFL.BFLY PT, peerR0V1, accum01, 2, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V1, accum05, 2, 0x1f;
--:-:2:-:1      SHFL.BFLY PT, peerR2V1, accum09, 2, 0x1f;

//Stage 1 adds for columns 2-3
04:-:-:-:1      FADD accum02, accum02, peerR0V2;
--:-:-:-:1      FADD accum06, accum06, peerR1V2;
--:-:-:-:1      FADD accum10, accum10, peerR2V2;

08:-:-:-:1      FADD accum03, accum03, peerR0V3;
--:-:-:-:1      FADD accum07, accum07, peerR1V3;
--:-:-:-:1      FADD accum11, accum11, peerR2V3;

//Stage 2 (mask 2) shuffles for columns 2-3
--:-:-:-:1      SHFL.BFLY PT, peerR0V2, accum02, 2, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V2, accum06, 2, 0x1f;
--:-:3:-:1      SHFL.BFLY PT, peerR2V2, accum10, 2, 0x1f;

--:-:-:-:1      SHFL.BFLY PT, peerR0V3, accum03, 2, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V3, accum07, 2, 0x1f;
--:-:4:-:1      SHFL.BFLY PT, peerR2V3, accum11, 2, 0x1f;

//Stage 2 adds for columns 0-1
01:-:-:-:1      FADD accum00, accum00, peerR0V0;
--:-:-:-:1      FADD accum04, accum04, peerR1V0;
--:-:-:-:1      FADD accum08, accum08, peerR2V0;

02:-:-:-:1      FADD accum01, accum01, peerR0V1;
--:-:-:-:1      FADD accum05, accum05, peerR1V1;
--:-:-:-:1      FADD accum09, accum09, peerR2V1;

//Stage 3 (mask 4) shuffles for columns 0-1
--:-:-:-:1      SHFL.BFLY PT, peerR0V0, accum00, 4, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V0, accum04, 4, 0x1f;
--:-:1:-:1      SHFL.BFLY PT, peerR2V0, accum08, 4, 0x1f;

--:-:-:-:1      SHFL.BFLY PT, peerR0V1, accum01, 4, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V1, accum05, 4, 0x1f;
--:-:2:-:1      SHFL.BFLY PT, peerR2V1, accum09, 4, 0x1f;

//Stage 2 adds for columns 2-3
04:-:-:-:1      FADD accum02, accum02, peerR0V2;
--:-:-:-:1      FADD accum06, accum06, peerR1V2;
--:-:-:-:1      FADD accum10, accum10, peerR2V2;

08:-:-:-:1      FADD accum03, accum03, peerR0V3;
--:-:-:-:1      FADD accum07, accum07, peerR1V3;
--:-:-:-:1      FADD accum11, accum11, peerR2V3;

//Stage 3 (mask 4) shuffles for columns 2-3
--:-:-:-:1      SHFL.BFLY PT, peerR0V2, accum02, 4, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V2, accum06, 4, 0x1f;
--:-:3:-:1      SHFL.BFLY PT, peerR2V2, accum10, 4, 0x1f;

--:-:-:-:1      SHFL.BFLY PT, peerR0V3, accum03, 4, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V3, accum07, 4, 0x1f;
--:-:4:-:1      SHFL.BFLY PT, peerR2V3, accum11, 4, 0x1f;

//Stage 3 adds for columns 0-1
01:-:-:-:1      FADD accum00, accum00, peerR0V0;
--:-:-:-:1      FADD accum04, accum04, peerR1V0;
--:-:-:-:1      FADD accum08, accum08, peerR2V0;

02:-:-:-:1      FADD accum01, accum01, peerR0V1;
--:-:-:-:1      FADD accum05, accum05, peerR1V1;
--:-:-:-:1      FADD accum09, accum09, peerR2V1;

//Stage 4 (mask 8) shuffles for columns 0-1
--:-:-:-:1      SHFL.BFLY PT, peerR0V0, accum00, 8, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V0, accum04, 8, 0x1f;
--:-:1:-:1      SHFL.BFLY PT, peerR2V0, accum08, 8, 0x1f;

--:-:-:-:1      SHFL.BFLY PT, peerR0V1, accum01, 8, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V1, accum05, 8, 0x1f;
--:-:2:-:1      SHFL.BFLY PT, peerR2V1, accum09, 8, 0x1f;

//Stage 3 adds for columns 2-3
04:-:-:-:1      FADD accum02, accum02, peerR0V2;
--:-:-:-:1      FADD accum06, accum06, peerR1V2;
--:-:-:-:1      FADD accum10, accum10, peerR2V2;

08:-:-:-:1      FADD accum03, accum03, peerR0V3;
--:-:-:-:1      FADD accum07, accum07, peerR1V3;
--:-:-:-:1      FADD accum11, accum11, peerR2V3;

//Stage 4 (mask 8) shuffles for columns 2-3
--:-:-:-:1      SHFL.BFLY PT, peerR0V2, accum02, 8, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V2, accum06, 8, 0x1f;
--:-:3:-:1      SHFL.BFLY PT, peerR2V2, accum10, 8, 0x1f;

--:-:-:-:1      SHFL.BFLY PT, peerR0V3, accum03, 8, 0x1f;
--:-:-:-:1      SHFL.BFLY PT, peerR1V3, accum07, 8, 0x1f;
--:-:4:-:1      SHFL.BFLY PT, peerR2V3, accum11, 8, 0x1f;

//Stage 4 adds for all four columns
01:-:-:-:1      FADD accum00, accum00, peerR0V0;
--:-:-:-:1      FADD accum04, accum04, peerR1V0;
--:-:-:-:1      FADD accum08, accum08, peerR2V0;

02:-:-:-:1      FADD accum01, accum01, peerR0V1;
--:-:-:-:1      FADD accum05, accum05, peerR1V1;
--:-:-:-:1      FADD accum09, accum09, peerR2V1;

04:-:-:-:1      FADD accum02, accum02, peerR0V2;
--:-:-:-:1      FADD accum06, accum06, peerR1V2;
--:-:-:-:1      FADD accum10, accum10, peerR2V2;

08:-:-:-:1      FADD accum03, accum03, peerR0V3;
--:-:-:-:1      FADD accum07, accum07, peerR1V3;
--:-:-:-:1      FADD accum11, accum11, peerR2V3;

//Compute store pointer
//hRow       = rowOffset + warpTid
//storeIndex = hRow * param_ldh + timeStep
//hAddr      = param_h + (storeIndex << 4)
//lockAddr   = param_lockAddr + (timeStep << 2)   (one lock word per timestep)
//Both addresses are formed from the timeStep value of the iteration that is
//finishing; timeStep is advanced at the end of this block.
<SCHEDULE_BLOCK>
--:-:-:-:1      IADD     hRow,       rowOffset,  warpTid;
--:-:-:-:1      XMAD     storeIndex, hRow,       param_ldh, timeStep;
--:-:-:-:1      LEA      hAddr0.CC,  storeIndex, param_h[0],      4;
--:-:-:-:1      LEA.HI.X hAddr1,     storeIndex, param_h[1], RZ, 4;
//The low half must set the carry flag (.CC) because LEA.HI.X consumes it;
//without it the high word would pick up the stale carry left by hAddr0.
--:-:-:-:1      LEA      lockAddr0.CC, timeStep, param_lockAddr[0],     2;
--:-:-:-:1      LEA.HI.X lockAddr1,    timeStep, param_lockAddr[1], RZ, 2;

//Conditional select for output
//Lanes with warpTid 0/1/2 (P2/P3/P4) take accumulator row 0/1/2
--:-:-:-:1  @P2 MOV output0, accum00;
--:-:-:-:1  @P3 MOV output0, accum04;
--:-:-:-:1  @P4 MOV output0, accum08;

--:-:-:-:1  @P2 MOV output1, accum01;
--:-:-:-:1  @P3 MOV output1, accum05;
--:-:-:-:1  @P4 MOV output1, accum09;

--:-:-:-:1  @P2 MOV output2, accum02;
--:-:-:-:1  @P3 MOV output2, accum06;
--:-:-:-:1  @P4 MOV output2, accum10;

--:-:-:-:1  @P2 MOV output3, accum03;
--:-:-:-:1  @P3 MOV output3, accum07;
--:-:-:-:3  @P4 MOV output3, accum11;

//Update timestep
//forward (param_reverse == 0): timeStep += 1, loop ends at param_seqLength
//reverse:                      timeStep -= 1, loop ends at -1
--:-:-:-:1      ISETP.EQ.AND P5, PT, RZ, param_reverse, PT;
--:-:-:-:1  @P5 MOV setVal, 1;
--:-:-:-:1 @!P5 MOV setVal, -1;
--:-:-:-:1  @P5 MOV expectVal, param_seqLength;
--:-:-:-:1 @!P5 MOV expectVal, -1;
--:-:-:-:1      IADD timeStep, timeStep, setVal;
</SCHEDULE_BLOCK>

//Save select predicates
//P2 is overwritten by the inter-block barrier code below and restored with R2P
//before the loop branch. NOTE(review): mask 0x0c presumably selects P2 and P3;
//P4 is not written anywhere between here and the loop branch.
--:-:-:-:1      P2R predSave, PR, RZ, 0x0c;

--:-:-:-:1      MOV reluclip, param_reluclip;

//Add bias for output
--:-:-:-:1      FADD output0, output0, biasValue;
--:-:-:-:1      FADD output1, output1, biasValue;
--:-:-:-:1      FADD output2, output2, biasValue;
--:-:-:-:3      FADD output3, output3, biasValue;

//Accumulate on top of current data
//(loadBuffer holds the h values fetched by the last predicated load emitted
// in the GEMM loop; the first add waits on that load's barrier flag)
20:-:-:-:1      FADD output0, output0, loadBuffer0;
--:-:-:-:1      FADD output1, output1, loadBuffer1;
--:-:-:-:1      FADD output2, output2, loadBuffer2;
--:-:-:-:3      FADD output3, output3, loadBuffer3;

//Activation function
//TODO: add others
//NOTE(review): presumably a clipped ReLU -- the !PT form taking the larger of
//(output, 0) and the PT form taking the smaller of (output, reluclip);
//confirm the FMNMX predicate convention.
--:-:-:-:2  FMNMX output0, output0, RZ, !PT;
--:-:-:-:2  FMNMX output1, output1, RZ, !PT;
--:-:-:-:2  FMNMX output2, output2, RZ, !PT;
--:-:-:-:2  FMNMX output3, output3, RZ, !PT;

--:-:-:-:2  FMNMX output0, output0, reluclip, PT;
--:-:-:-:2  FMNMX output1, output1, reluclip, PT;
--:-:-:-:2  FMNMX output2, output2, reluclip, PT;
--:-:-:-:2  FMNMX output3, output3, reluclip, PT;

//Conditional store
//Only lanes with P0 set (warpTid < 3 and row inside param_rowSize) write h
--:-:-:-:1  @P0 STG.E.CI.128 [hAddr], output;

//Compute predicate for time unrolling loop
//P5 = (timeStep != expectVal); expectVal still holds the loop end value here
--:-:-:Y:d      ISETP.NE.AND P5, PT, timeStep, expectVal, PT;

//P2 = (tid != 0)
//expectVal = param_numBlks (value the lock word must reach)
//setVal = 1
--:-:-:-:1      ISETP.NE.AND P2, PT, tid, RZ, PT;
--:-:-:-:1      MOV expectVal, param_numBlks;
--:-:-:Y:b      MOV setVal, 1;

//Barrier for all blocks
//Make this block's stores visible, then synchronise the block's own threads
--:-:-:-:f      MEMBAR.GL;
--:-:-:-:5      BAR.SYNC 0;

//Thread 0 only: add setVal (1) to this timestep's lock word.
//All other threads (P2 set) skip straight to SSY_TARGET1.
--:-:-:-:2      SSY SSY_TARGET1;
--:-:-:-:d  @P2 SYNC;

--:-:-:Y:2      ATOM.E.ADD RZ, [lockAddr], setVal;
--:-:-:-:d      SYNC;

SSY_TARGET1:
//Thread 0 only: poll the lock word until it equals expectVal.
//All other threads (P2 set) skip straight to SSY_TARGET2.
--:-:-:-:1      SSY SSY_TARGET2;
--:-:-:-:d  @P2 SYNC;

SPINLOCK:
//P2 is reused here as the "keep spinning" condition
--:-:1:Y:2      LDG.E lockVal, [lockAddr];
01:-:-:Y:d      ISETP.NE.AND P2, PT, lockVal, expectVal, PT;
--:-:-:-:5  @P2 BRA.U SPINLOCK;
--:-:-:-:d      SYNC;

SSY_TARGET2:
//Hold the rest of the block until thread 0 has left the spin loop
--:-:-:-:5      BAR.SYNC 0;

//Restore select predicates
--:-:-:-:1      R2P PR, predSave, 0x0c;

//Conditional branch back to beginning of loop
--:-:-:Y:5  @P5 BRA.U UNROLLING_LOOP;

--:-:-:-:5      EXIT;
